# -*- coding: utf-8 -*-
'''
Created on 2016年12月29日

@author: ZhuJiahui
'''

import random
import numpy as np


class TransPreparaion(object):
    '''
    Data preparation for translation-based knowledge-graph embedding models.

    On construction it:
      * builds name <-> id dictionaries for entities and relations (er_to_id),
      * converts the train/validate/test string triples into id triples,
      * indexes the training triples by relation and computes, per relation,
        the average number of tails per head (tph) and heads per tail (hpt).
    Negative triples are produced by negative_sample / generate_batch.

    The (misspelled) class name is kept unchanged for backward compatibility.
    '''

    # The head is corrupted when randint(0, 2 * _PROB_SCALE) < head_prop, so
    # head_prop == _PROB_SCALE is a fair head/tail coin flip.
    _PROB_SCALE = 500
    # Number of extra draws allowed when a corruption hits a known positive triple.
    _MAX_RETRIES = 10

    def __init__(self, entity_list, relation_list, train_triples, validate_triples=None, test_triples=None, negative_type='bern'):
        '''
        :param entity_list: entity names; ids assigned in list order (duplicates ignored)
        :param relation_list: relation names; ids assigned in list order (duplicates ignored)
        :param train_triples: training triples, each indexable as (head, relation, tail) names
        :param validate_triples: validation triples in the same format (default: none)
        :param test_triples: test triples in the same format (default: none)
        :param negative_type: 'bern' for relation-aware corruption, anything else for uniform
        '''
        self.entity_list = entity_list
        self.relation_list = relation_list
        self.train_triples = train_triples
        # None defaults instead of [] avoid sharing one mutable list between instances.
        self.validate_triples = [] if validate_triples is None else validate_triples
        self.test_triples = [] if test_triples is None else test_triples
        self.negative_type = negative_type
        self.entity_num = len(entity_list)
        self.relation_num = len(relation_list)
        print('Negative type is ', self.negative_type)

        self.er_to_id()
        self.prepare_data()

    @staticmethod
    def _build_index(names):
        '''Return (name->id, id->name) dicts with ids 0, 1, ... in first-occurrence order.'''
        name2id, id2name = dict(), dict()
        for name in names:
            if name not in name2id:  # a repeated name keeps its first id
                new_id = len(name2id)
                name2id[name] = new_id
                id2name[new_id] = name
        return name2id, id2name

    def er_to_id(self):
        '''
        Build the bidirectional entity and relation mapping dictionaries.

        Sets entity2id_dict / id2entity_dict and relation2id_dict /
        id2relation_dict (ids start at 0), then updates entity_num and
        relation_num to the number of distinct names.
        '''
        self.entity2id_dict, self.id2entity_dict = self._build_index(self.entity_list)
        # Relations go into their own reverse dictionary (previously they
        # overwrote entries of id2entity_dict).
        self.relation2id_dict, self.id2relation_dict = self._build_index(self.relation_list)

        self.entity_num, self.relation_num = len(self.entity2id_dict), len(self.relation2id_dict)
        print("entity num : " + str(self.entity_num))
        print("relation num : " + str(self.relation_num))

    def read_triple(self, triples):
        '''
        Map string triples to id triples.

        :param triples: iterable of (head, relation, tail) names
        :return: [((int)h, (int)r, (int)t)] list
        :raises KeyError: if a name is not in the entity/relation lists
        '''
        return [(self.entity2id_dict[x[0]], self.relation2id_dict[x[1]], self.entity2id_dict[x[2]]) for x in triples]

    def add_dict_kv(self, dic, relation, new_entity):
        '''
        Add new_entity to the set stored under key `relation` in dic.

        :param dic: dict mapping relation -> set of entity ids (mutated in place)
        :param relation: relation key
        :param new_entity: entity id to add
        '''
        dic.setdefault(relation, set()).add(new_entity)

    def add_dict_kkv(self, dic, relation, entity1, entity2):
        '''
        Add entity2 to dic[relation][entity1], creating nested containers as needed.

        :param dic: dict mapping relation -> {entity1 -> set of entity2} (mutated in place)
        :param relation: relation key
        :param entity1: inner key entity id
        :param entity2: entity id added to the inner set
        '''
        dic.setdefault(relation, dict()).setdefault(entity1, set()).add(entity2)

    def get_tph_hpt(self, r_h_t_dict, r_t_h_dict):
        '''
        Compute, per relation, the average number of tails per head and heads per tail.

        :param r_h_t_dict: relation -> head -> set of tails
        :param r_t_h_dict: relation -> tail -> set of heads
        :return: {int: (float64, float64)} keyed by relation id; the first value
                 is the mean number of distinct tails per head (tph), the second
                 the mean number of distinct heads per tail (hpt)
        '''
        relation2tphhpt = dict()
        for each_relation, head_to_tails in r_h_t_dict.items():
            tail_to_heads = r_t_h_dict[each_relation]
            total_tail_count = sum(len(tails) for tails in head_to_tails.values())
            this_tph = np.true_divide(total_tail_count, len(head_to_tails))
            total_head_count = sum(len(heads) for heads in tail_to_heads.values())
            this_hpt = np.true_divide(total_head_count, len(tail_to_heads))
            relation2tphhpt[each_relation] = (this_tph, this_hpt)
        return relation2tphhpt

    def _corrupt(self, h, r, t, head_prop, head_pool, tail_pool):
        '''
        Return (h, r, t) with exactly one side replaced.

        The head is replaced when randint(0, 2 * _PROB_SCALE) < head_prop,
        otherwise the tail. The replacement comes from the given pool (a
        sequence), or uniformly from all entity ids when the pool is None/empty.
        '''
        if np.random.randint(0, 2 * self._PROB_SCALE) < head_prop:
            new_h = random.choice(head_pool) if head_pool else np.random.randint(0, self.entity_num)
            return (new_h, r, t)
        new_t = random.choice(tail_pool) if tail_pool else np.random.randint(0, self.entity_num)
        return (h, r, new_t)

    def negative_sample(self, id_triples, negative_type):
        '''
        Generate one negative triple per positive triple by corrupting its
        head or its tail (never both).

        'bern': the head is corrupted with probability about tph / (tph + hpt)
        of the triple's relation. The replacement is drawn from the heads
        (resp. tails) seen with that relation in the training data.
        Any other value: head or tail with equal probability, and the
        replacement is drawn uniformly from all entities. Relations absent from
        the training data also use the uniform scheme.

        A candidate equal to a training triple, or to any triple in id_triples,
        is redrawn up to _MAX_RETRIES extra times. After that, one unchecked
        uniform corruption is used instead.

        :param id_triples: [(h, r, t)] list of positive id triples
        :param negative_type: 'bern', or any other value for uniform sampling
        :return: list of negative (h, r, t) triples, aligned with id_triples
        '''
        # Sets give O(1) membership tests instead of scanning a list per draw.
        train_set = self.train_triple_set
        batch_set = set(id_triples)

        negative_id_triples = list()
        for (h, r, t) in id_triples:
            if negative_type == 'bern' and r in self.relation2tphhpt:
                tph, hpt = self.relation2tphhpt[r]
                head_prop = int(2.0 * tph * self._PROB_SCALE / (tph + hpt))
                head_pool, tail_pool = self._train_r_h_pool[r], self._train_r_t_pool[r]
            else:
                head_prop = self._PROB_SCALE
                head_pool = tail_pool = None

            negative = None
            # Each attempt corrupts the ORIGINAL triple, so at most one side changes.
            for _ in range(self._MAX_RETRIES + 1):
                candidate = self._corrupt(h, r, t, head_prop, head_pool, tail_pool)
                if candidate not in train_set and candidate not in batch_set:
                    negative = candidate
                    break
            if negative is None:
                # Every draw was a known positive (e.g. a tiny candidate pool):
                # fall back to one unchecked uniform corruption.
                negative = self._corrupt(h, r, t, head_prop, None, None)
            negative_id_triples.append(negative)

        return negative_id_triples

    def prepare_data(self):
        '''
        Convert all triple sets to ids and build the training-data indexes.

        Sets train/validate/test_id_triples, train_triple_set, the
        relation -> head/tail set dictionaries, the relation -> head -> tails
        and relation -> tail -> heads dictionaries, and relation2tphhpt.
        '''
        self.train_id_triples = self.read_triple(self.train_triples)
        self.train_triple_set = set(self.train_id_triples)
        self.validate_id_triples = self.read_triple(self.validate_triples)
        self.test_id_triples = self.read_triple(self.test_triples)

        # relation -> set of head ids / set of tail ids (training data only)
        self.train_r_h_dict, self.train_r_t_dict = dict(), dict()
        # relation -> head -> set of tails, and relation -> tail -> set of heads
        self.train_r_h_t_dict, self.train_r_t_h_dict = dict(), dict()

        for (h, r, t) in self.train_id_triples:
            self.add_dict_kv(self.train_r_h_dict, r, h)
            self.add_dict_kv(self.train_r_t_dict, r, t)
            self.add_dict_kkv(self.train_r_h_t_dict, r, h, t)
            self.add_dict_kkv(self.train_r_t_h_dict, r, t, h)

        # random.choice needs a sequence (random.sample on a set raises TypeError
        # on Python 3.11+); sorting keeps draws reproducible under a fixed seed.
        self._train_r_h_pool = {r: tuple(sorted(hs)) for r, hs in self.train_r_h_dict.items()}
        self._train_r_t_pool = {r: tuple(sorted(ts)) for r, ts in self.train_r_t_dict.items()}

        # {relation id: (tph, hpt)} computed from the training data
        self.relation2tphhpt = self.get_tph_hpt(self.train_r_h_t_dict, self.train_r_t_h_dict)

    def generate_batch(self, id_triples, batch_size, negative_type):
        '''
        Draw a training batch and its negative samples.

        :param id_triples: [(h, r, t)] list to sample from (without replacement)
        :param batch_size: number of positive triples in the batch
        :param negative_type: negative sampling type, passed to negative_sample
        :return: (positive batch, negative batch), aligned element-wise
        :raises ValueError: if batch_size exceeds len(id_triples)
        '''
        batch_pos = random.sample(id_triples, batch_size)
        batch_neg = self.negative_sample(batch_pos, negative_type)
        return batch_pos, batch_neg

    
if __name__ == '__main__':
    # This module only provides TransPreparaion; running it directly does nothing.
    ...
